

include("../src/qmcmary.jl")
using ..qmcmary

using LinearAlgebra
using Random


"""
    sgn_scratch(ss::ScrollSVD{T}) where T

Recompute a log-magnitude from the factorizations stored in `ss`.

Despite its name, the value returned is a logarithm of an absolute value, not a sign:

    log|det(Vt*UL) * det(UR*U)| + sum(log, S) + sum(log, DLB) + sum(log, DRB)

where `U`, `S`, `Vt` come from the SVD of the rescaled matrix `M` built below, and
`DLB`/`DRB` hold the singular values greater than one of the left/right factors.

The left factors `(VL, DL, UL)` and right factors `(UR, DR, VR)` are chosen from `ss.F`
according to `findfirst(ss.L)`; a missing side is replaced by the identity.

NOTE(review): if every `U`/`Vt` factor is unitary, this should equal
`log|det(I + UR*DR*VR*VL*DL*UL)|`, the matrix inverted in the author's original reference
formula `gtt = inv(I + UR*DR*VR*VL*DL*UL)` -- confirm against the `ScrollSVD` definition.
"""
function sgn_scratch(ss::ScrollSVD{T}) where T
    n = size(ss.B[end])[1]
    eye = Diagonal(ones(n))
    k = findfirst(ss.L)
    if k === nothing
        # No flagged entry in ss.L: the left side is the identity, the right side is the
        # last stored factorization.
        fr = ss.F[end]
        VL, DL, UL = eye, eye, eye
        UR, DR, VR = fr.U, Diagonal(fr.S), fr.Vt
    elseif k == 1
        # First entry flagged: only a left factorization exists (there is no F[k-1]).
        # Original author's note: this assignment can abruptly lose numerical stability;
        # the suggested (not enabled) alternative puts F[1] on the right side instead,
        # i.e. VL = DL = UL = I and UR, DR, VR = F[1].U, Diagonal(F[1].S), F[1].Vt.
        fl = ss.F[1]
        VL, DL, UL = fl.U, Diagonal(fl.S), fl.Vt
        UR, DR, VR = eye, eye, eye
    else
        fl = ss.F[k]
        fr = ss.F[k-1]
        VL, DL, UL = fl.U, Diagonal(fl.S), fl.Vt
        UR, DR, VR = fr.U, Diagonal(fr.S), fr.Vt
    end

    # Split each diagonal of scales into a "big" part (entries > 1, others set to 1) and
    # a "small" part (entries <= 1, others set to 1), stored as Float64.
    bigpart(D) = Diagonal(Float64[D[i, i] > 1.0 ? D[i, i] : 1.0 for i in Base.OneTo(n)])
    smallpart(D) = Diagonal(Float64[D[i, i] > 1.0 ? 1.0 : D[i, i] for i in Base.OneTo(n)])
    DLB, DLS = bigpart(DL), smallpart(DL)
    DRB, DRS = bigpart(DR), smallpart(DR)

    # Rescaled matrix: the big scales are divided out and the small ones multiplied in.
    # The grouping of the products is kept exactly as in the reference implementation.
    M = inv(DRB)*adjoint(UL*UR)*inv(DLB) + DRS*(VR*VL)*DLS
    fact = svd(M; alg=LinearAlgebra.QRIteration())

    # Log-magnitude of the orthogonal/unitary pieces, then the log of every scale.
    logw = log(abs(det(fact.Vt*UL)*det(UR*fact.U)))
    for scales in (fact.S, diag(DLB), diag(DRB)), s in scales
        logw += log(s)
    end
    return logw
end


"""
    run()

Print the matrices returned by the Ising and Gauge1 constructors for matching inputs.

For each case the Ising result is printed first and the Gauge1 result second, presumably
so the two can be compared by eye. The Gauge1 calls receive
`2acosh(exp(Ui*dt/2))/Ui/dt` as their fourth argument. The last case uses a random
symmetric 4x4 matrix and a random two-site configuration of +1/-1 values.
"""
function run()
    Ui = 2.0
    dt = 0.1
    ax = 2acosh(exp(Ui*dt/2))/Ui/dt
    onsite = [1 0; 0 1]

    # Single-site B matrices for both field values.
    for s in (1, -1)
        b_ising = bmat_IsingADX(onsite, Val(s), Ui, 0.0, 0.0, 1.0, dt)
        b_gauge = bmat_Gauge1ADX(onsite, Val(s), Ui, ax, 0.0, 0.5, 0.0, dt)
        println(b_ising)
        println(b_gauge)
    end

    # Delta matrices for both orderings of the (third, fourth) argument pair.
    for (sa, sb) in ((1, -1), (-1, 1))
        d_ising = Δmat_IsingND(1, 1, sa, sb, 2.0, 0.0, 0.0, 1.0, dt)
        d_gauge = Δmat_Gauge1ND(1, 1, sa, sb, 2.0, ax, 0.0, 0.5, 0.0, dt)
        println(d_ising)
        println(d_gauge)
    end

    # Multi-site B matrices on a random symmetric matrix.
    raw = rand(4, 4)
    hop = raw + transpose(raw)
    spins = 2(round.(rand(2)) .- 0.5)
    b_ising = bmat_IsingND(
        hop, spins, fill(Ui, 2), fill(0.0, 2), fill(0.0, 2), fill(1.0, 2), dt
    )
    b_gauge = bmat_Gauge1ND(
        hop, spins, fill(Ui, 2), fill(ax, 2), fill(0.0, 2), fill(0.5, 2), zeros(2), dt
    )
    println(b_ising)
    println(b_gauge)
    return nothing
end



"""
    run2()

Print finite differences of the first B matrix next to the arrays from `pbpa_calc_xi`.

The parameters `rax`, `iax`, `rbx`, `ibx` at site `ni` are perturbed one at a time by
`0.001` (the previous perturbation is subtracted again first). After each perturbation the
B matrices are rebuilt, and rows `ni` and `ni + Nx` of the real and imaginary parts of
`(allbmats2[1] - allbmats1[1]) / 0.001` are printed, each followed by the matching slice
`[1, 1, :]` / `[1, 2, :]` of the corresponding derivative array.
"""
function run2()
    L = 4
    Nx = L^2
    dt = 0.1
    tp = 0.25
    h = 0.001
    hk = lattice_tprim_square(ComplexF64, L, -1.0+0.0im, tp+0.0im)

    Ui = 4*ones(Nx)
    rax_base = @. 2acosh(exp(Ui*dt/2))/Ui/dt
    rax = rax_base + rand(Nx)
    iax = rand(Nx)
    rbx = 0.5*rand(Nx)
    ibx = rand(Nx)

    Nt = 50
    Ng = 5
    # Random +1/-1 auxiliary-field configuration, one row per time slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for t in axes(hscfg, 1)
        spins = 2(round.(rand(Nx)) .- 0.5)
        hscfg[t, :] .= spins
    end
    hscfg[1, 2] = 1

    # Reference B matrices and derivative arrays at the unperturbed parameters.
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    _, _, allbmats1, _ = initialize_SS_Gauge1(Nt, Ng, sp, hscfg, rax, iax, rbx, ibx)
    ni = 2
    tapemap = pbpa_calc_tapemap(sp, rax, iax, rbx, ibx)
    adf_r, adf_i, idf_r, idf_i, bdf_r, bdf_i, ibdf_r, ibdf_i = pbpa_calc_xi(
        sp, hscfg[:, ni], ni, rax[ni], iax[ni], rbx[ni], ibx[ni]; tapemap=tapemap
    )

    # Perturb, in order: real a, imaginary a, real b, imaginary b at site ni.
    params = (rax, iax, rbx, ibx)
    derivs = ((adf_r, adf_i), (idf_r, idf_i), (bdf_r, bdf_i), (ibdf_r, ibdf_i))
    for k in eachindex(params)
        if k > 1
            # Undo the previous perturbation before applying the next one.
            params[k-1][ni] -= h
        end
        params[k][ni] += h
        spk = default_splitting(Nt, hk, Ui; Z2=true)
        _, _, allbmats2, _ = initialize_SS_Gauge1(
            Nt, Ng, spk, hscfg, rax, iax, rbx, ibx
        )
        df_r, df_i = derivs[k]

        rdiff = real(allbmats2[1]) - real(allbmats1[1])
        println("r1", rdiff[ni, :]/h)
        println("adf", df_r[1, 1, :])
        println("r dn",rdiff[ni+Nx, :]/h)
        println("adf dn", df_r[1, 2, :])

        idiff = imag(allbmats2[1]) - imag(allbmats1[1])
        println("i1", idiff[ni, :]/h)
        println("adf", df_i[1, 1, :])
        println("i dn", idiff[ni+Nx, :]/h)
        println("adf dn", df_i[1, 2, :])
    end
    return nothing
end


"""
    run3()

Print the gradients from `meas_grada` next to finite differences of a log-weight.

The log-weight is `sgn_scratch(ss)` plus, for every entry of `hscfg`, `log|p|` when the
entry is `-1` and `log|1 - p|` otherwise, with `p = rbx + im*ibx` at that site. Each of
`rax`, `iax`, `rbx`, `ibx` is perturbed at site `ni` in turn (steps `0.001`, `0.001`,
`0.001`, `0.0001`), and the lines printed are

    ad <gradient entry from meas_grada>
    nd <(logS1 - logS2) / step>
"""
function run3()
    L = 4
    Nx = L^2
    dt = 0.1
    tp = 0.25
    hk = lattice_tprim_square(ComplexF64, L, -1.0+0.0im, tp+0.0im)

    Ui = 4*ones(Nx)
    rax = @. 2acosh(exp(Ui*dt/2))/Ui/dt
    iax = 0.1*rand(Nx)
    rbx = 0.5*rand(Nx)
    ibx = 0.1*rand(Nx)

    Nt = 50
    Ng = 5
    # Random +1/-1 auxiliary-field configuration, one row per time slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for t in axes(hscfg, 1)
        spins = 2(round.(rand(Nx)) .- 0.5)
        hscfg[t, :] .= spins
    end
    sp = default_splitting(Nt, hk, Ui; Z2=true)

    # Log-weight for the current parameters. `p` is rebuilt from the current rbx/ibx on
    # every call; original author's note: if `p` is NOT refreshed after perturbing b, the
    # result should agree with ignorew=True.
    function logweight(state)
        local acc = sgn_scratch(state)
        local p = @. rbx + im*ibx
        for t in Base.OneTo(Nt), x in Base.OneTo(Nx)
            if hscfg[t, x] == -1
                acc += log(abs(p[x]))
            else
                acc += log(abs(complex(1.0, 0.0)-p[x]))
            end
        end
        return acc
    end

    # Reference log-weight and gradients at the unperturbed parameters.
    _, _, allbmats1, ss = initialize_SS_Gauge1(Nt, Ng, sp, hscfg, rax, iax, rbx, ibx)
    logS1 = logweight(ss)
    tapemap = pbpa_calc_tapemap(sp, rax, iax, rbx, ibx)
    drax, diax, drbx, dibx = meas_grada(
        ss, allbmats1, sp, hscfg, rax, iax, rbx, ibx; tapemap=tapemap
    )

    ni = 3
    params = (rax, iax, rbx, ibx)
    grads = (drax, diax, drbx, dibx)
    steps = (0.001, 0.001, 0.001, 0.0001)
    for k in eachindex(params)
        if k > 1
            # Undo the previous perturbation before applying the next one.
            params[k-1][ni] -= steps[k-1]
        end
        params[k][ni] += steps[k]
        _, _, _, ss2 = initialize_SS_Gauge1(Nt, Ng, sp, hscfg, rax, iax, rbx, ibx)
        logS2 = logweight(ss2)
        println("ad ", grads[k][ni])
        println("nd ", (-logS2+logS1)/steps[k])
    end
    return nothing
end


"""
    run4()

Print `sgn_scratch` before and after one sweep for the Gauge1 and the plain code paths.

Both paths start from the same random +1/-1 configuration (the plain path works on a copy
taken before the Gauge1 sweep runs), and the global RNG is seeded with `1234` right
before each sweep. Four values are printed, in order: Gauge1 initial, Gauge1 after
`dqmc_step_Gauge1`, plain initial, plain after `dqmc_step`.
"""
function run4()
    L = 4
    Nx = L^2
    dt = 0.1
    tp = 0.25
    hk = lattice_tprim_square(ComplexF64, L, -1.0+0.0im, tp+0.0im)

    Ui = 4*ones(Nx)
    rax = @. 2acosh(exp(Ui*dt/2))/Ui/dt
    iax = zeros(Nx)
    rbx = 0.5*ones(Nx)
    ibx = zeros(Nx)

    Nt = 50
    Ng = 5
    # Random +1/-1 auxiliary-field configuration, one row per time slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for t in axes(hscfg, 1)
        spins = 2(round.(rand(Nx)) .- 0.5)
        hscfg[t, :] .= spins
    end
    # Independent copy for the plain path, taken before any sweep touches hscfg.
    hscfg2 = copy(hscfg)
    sp = default_splitting(Nt, hk, Ui; Z2=true)

    report(state) = println(sgn_scratch(state))

    # Gauge1 path.
    sslen, bmats, allbmats1, ss = initialize_SS_Gauge1(
        Nt, Ng, sp, hscfg, rax, iax, rbx, ibx
    )
    report(ss)
    Random.seed!(1234)
    _, _, _, ss, _, _ = dqmc_step_Gauge1(
        Nt, Ng, sp, hscfg, rax, iax, rbx, ibx, sslen, bmats, allbmats1, ss
    )
    report(ss)

    # Plain path; fresh zeros/ones arrays are allocated for every call.
    sslen, bmats, allbmats2, ss2 = initialize_SS(
        Nt, Ng, sp, hscfg2, zeros(Nx), zeros(Nx), ones(Nx)
    )
    report(ss2)
    Random.seed!(1234)
    _, _, _, ss2, _, _ = dqmc_step(
        Nt, Ng, sp, hscfg2, zeros(Nx), zeros(Nx), ones(Nx), sslen, bmats, allbmats2, ss2
    )
    report(ss2)
    return nothing
end

# Script entry point: the selected check runs as soon as this file is executed or
# included. Only `run3()` is enabled; uncomment another call to run that check instead.
#run()
#run2()
run3()
#run4()
